/*

 * Licensed to the Apache Software Foundation (ASF) under one

 * or more contributor license agreements.  See the NOTICE file

 * distributed with this work for additional information

 * regarding copyright ownership.  The ASF licenses this file

 * to you under the Apache License, Version 2.0 (the

 * "License"); you may not use this file except in compliance

 * with the License.  You may obtain a copy of the License at

 *

 *     http://www.apache.org/licenses/LICENSE-2.0

 *

 * Unless required by applicable law or agreed to in writing, software

 * distributed under the License is distributed on an "AS IS" BASIS,

 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 * See the License for the specific language governing permissions and

 * limitations under the License.

 */

package com.bff.gaia.unified.runners.gaia.translation.functions;



import com.bff.gaia.unified.model.pipeline.v1.RunnerApi;

import com.bff.gaia.unified.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId;

import com.bff.gaia.unified.runners.core.construction.PTransformTranslation;

import com.bff.gaia.unified.runners.core.construction.graph.ExecutableStage;

import com.bff.gaia.unified.runners.core.construction.graph.SideInputReference;

import com.bff.gaia.unified.runners.fnexecution.state.StateRequestHandler;

import com.bff.gaia.unified.runners.fnexecution.state.StateRequestHandlers.SideInputHandler;

import com.bff.gaia.unified.runners.fnexecution.state.StateRequestHandlers.SideInputHandlerFactory;

import com.bff.gaia.unified.sdk.coders.Coder;

import com.bff.gaia.unified.sdk.coders.KvCoder;

import com.bff.gaia.unified.sdk.transforms.windowing.BoundedWindow;

import com.bff.gaia.unified.sdk.values.KV;

import com.bff.gaia.unified.sdk.values.PCollectionView;

import com.bff.gaia.unified.vendor.guava.com.google.common.collect.ImmutableMap;



import java.io.ByteArrayOutputStream;

import java.io.IOException;

import java.util.ArrayList;

import java.util.Arrays;

import java.util.Map;



import static com.bff.gaia.unified.vendor.guava.com.google.common.base.Preconditions.checkArgument;

import static com.bff.gaia.unified.vendor.guava.com.google.common.base.Preconditions.checkNotNull;



/**

 * {@link StateRequestHandler} that uses {@link com.bff.gaia.unified.runners.core.SideInputHandler} to

 * access the Gaia broadcast state that represents side inputs.

 */

public class GaiaStreamingSideInputHandlerFactory implements SideInputHandlerFactory {



  // Map from side input id to global PCollection id.

  private final Map<SideInputId, PCollectionView<?>> sideInputToCollection;

  private final com.bff.gaia.unified.runners.core.SideInputHandler runnerHandler;



  /**

   * Creates a new state handler for the given stage. Note that this requires a traversal of the

   * stage itself, so this should only be called once per stage rather than once per bundle.

   */

  public static GaiaStreamingSideInputHandlerFactory forStage(

      ExecutableStage stage,

      Map<SideInputId, PCollectionView<?>> viewMapping,

      com.bff.gaia.unified.runners.core.SideInputHandler runnerHandler) {

    ImmutableMap.Builder<SideInputId, PCollectionView<?>> sideInputBuilder = ImmutableMap.builder();

    for (SideInputReference sideInput : stage.getSideInputs()) {

      SideInputId sideInputId =

          SideInputId.newBuilder()

              .setTransformId(sideInput.transform().getId())

              .setLocalName(sideInput.localName())

              .build();

      sideInputBuilder.put(

          sideInputId,

          checkNotNull(

              viewMapping.get(sideInputId),

              "No side input for %s/%s",

              sideInputId.getTransformId(),

              sideInputId.getLocalName()));

    }



    GaiaStreamingSideInputHandlerFactory factory =

        new GaiaStreamingSideInputHandlerFactory(sideInputBuilder.build(), runnerHandler);

    return factory;

  }



  private GaiaStreamingSideInputHandlerFactory(

      Map<SideInputId, PCollectionView<?>> sideInputToCollection,

      com.bff.gaia.unified.runners.core.SideInputHandler runnerHandler) {

    this.sideInputToCollection = sideInputToCollection;

    this.runnerHandler = runnerHandler;

  }



  @Override

  public <T, V, W extends BoundedWindow> SideInputHandler<V, W> forSideInput(

      String transformId,

      String sideInputId,

      RunnerApi.FunctionSpec accessPattern,

      Coder<T> elementCoder,

      Coder<W> windowCoder) {



    PCollectionView collectionNode =

        sideInputToCollection.get(

            SideInputId.newBuilder().setTransformId(transformId).setLocalName(sideInputId).build());

    checkArgument(collectionNode != null, "No side input for %s/%s", transformId, sideInputId);



    if (PTransformTranslation.ITERABLE_SIDE_INPUT.equals(accessPattern.getUrn())) {

      @SuppressWarnings("unchecked") // T == V

      Coder<V> outputCoder = (Coder<V>) elementCoder;

      return forIterableSideInput(collectionNode, outputCoder);

    } else if (PTransformTranslation.MULTIMAP_SIDE_INPUT.equals(accessPattern.getUrn())) {

      @SuppressWarnings("unchecked") // T == KV<?, V>

      KvCoder<?, V> kvCoder = (KvCoder<?, V>) elementCoder;

      return forMultimapSideInput(collectionNode, kvCoder.getKeyCoder(), kvCoder.getValueCoder());

    } else {

      throw new IllegalArgumentException(

          String.format("Unknown side input access pattern: %s", accessPattern));

    }

  }



  private <T, W extends BoundedWindow> SideInputHandler<T, W> forIterableSideInput(

      PCollectionView<?> collection, Coder<T> elementCoder) {



    return new SideInputHandler<T, W>() {

      @Override

      public Iterable<T> get(byte[] key, W window) {

        return checkNotNull(

            (Iterable<T>) runnerHandler.getIterable(collection, window),

            "Element processed by SDK before side input is ready");

      }



      @Override

      public Coder<T> resultCoder() {

        return elementCoder;

      }

    };

  }



  private <K, V, W extends BoundedWindow> SideInputHandler<V, W> forMultimapSideInput(

      PCollectionView<?> collection, Coder<K> keyCoder, Coder<V> valueCoder) {



    return new SideInputHandler<V, W>() {

      @Override

      public Iterable<V> get(byte[] key, W window) {

        Iterable<KV<K, V>> values =

            (Iterable<KV<K, V>>) runnerHandler.getIterable(collection, window);

        ArrayList<V> result = new ArrayList<>();

        // find values for the given key

        for (KV<K, V> kv : values) {

          ByteArrayOutputStream bos = new ByteArrayOutputStream();

          try {

            keyCoder.encode(kv.getKey(), bos);

            if (Arrays.equals(key, bos.toByteArray())) {

              result.add(kv.getValue());

            }

          } catch (IOException ex) {

            throw new RuntimeException(ex);

          }

        }

        return result;

      }



      @Override

      public Coder<V> resultCoder() {

        return valueCoder;

      }

    };

  }

}